import random
from torchvision.transforms import functional as F
from PIL import Image


class Compose(object):
    """Chain several ``(image, target)`` transforms into a single callable."""

    def __init__(self, transforms):
        # Iterable of callables; each takes (image, target) and returns a
        # two-item result that is unpacked back into (image, target).
        self.transforms = transforms

    def __call__(self, image, target):
        """Run every transform in order, feeding each one the previous result.

        Args:
            image: The image handed to the first transform.
            target: The annotation object handed to the first transform.

        Returns:
            The ``(image, target)`` tuple produced by the last transform, or
            the inputs unchanged when there are no transforms.
        """
        current_image, current_target = image, target
        for transform in self.transforms:
            current_image, current_target = transform(current_image, current_target)
        return current_image, current_target


class ToTensor(object):
    """Convert the image of an ``(image, target)`` pair to a tensor."""

    def __call__(self, image, target):
        """Return ``F.to_tensor(image)`` together with the untouched target.

        Args:
            image: Image to convert (a PIL image, per the original docstring).
            target: Annotation object; passed through unchanged.

        Returns:
            Tuple ``(tensor_image, target)``.
        """
        return F.to_tensor(image), target


class RandomHorizontalFlip(object):
    """Randomly mirror an image left-to-right together with its boxes and masks."""

    def __init__(self, prob=0.5):
        # Probability that a given call flips the sample.
        self.prob = prob

    def __call__(self, image, target):
        """Flip the sample horizontally with probability ``self.prob``.

        Args:
            image: Tensor-like image whose last two axes are read as
                (height, width); it is flipped along its last axis.
            target: Dict with a ``"boxes"`` entry indexed as
                ``[:, (xmin, ymin, xmax, ymax)]`` and, optionally, a
                ``"masks"`` entry that is flipped along its last axis.

        Returns:
            Tuple ``(image, target)``. When a flip happens, the boxes are
            updated in place and ``target`` itself is modified.
        """
        # Guard clause: leave the sample alone when the draw misses.
        if not (random.random() < self.prob):
            return image, target

        _, img_width = image.shape[-2:]
        mirrored = image.flip(-1)

        boxes = target["boxes"]
        # Mirroring swaps the roles of the x edges: the new xmin comes from
        # the old xmax and vice versa, both measured from the right border.
        boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
        target["boxes"] = boxes

        if "masks" in target:
            target["masks"] = target["masks"].flip(-1)

        return mirrored, target


class UniformSize(object):
    """Letterbox a PIL image and its annotations onto a fixed-size canvas.

    The image is scaled by one factor so that it fits inside ``size`` while
    keeping its aspect ratio, then centred on a black canvas. Box coordinates
    and, when present, the mask are scaled and shifted the same way.

    Args:
        size: Output ``(width, height)`` in pixels. Defaults to ``(512, 512)``.
    """

    def __init__(self, size=(512, 512)):
        self.size = tuple(size)

    def __call__(self, image, target):
        """Resize and pad ``image`` to ``self.size`` and adjust ``target``.

        Args:
            image: PIL image to letterbox.
            target: Dict with a ``"boxes"`` entry in
                ``xmin, ymin, xmax, ymax`` order and, optionally, a
                ``"masks"`` entry holding a PIL image.

        Returns:
            Tuple ``(padded_image, target)``. ``padded_image`` is a new RGB
            image of exactly ``self.size``; ``target`` is modified in place.
        """
        src_w, src_h = image.size
        dst_w, dst_h = self.size

        # One scale factor for both axes keeps the aspect ratio.
        scale = min(dst_w / src_w, dst_h / src_h)
        # Clamp to 1 px so extreme aspect ratios never ask for a 0-sized resize.
        new_w = max(1, int(src_w * scale))
        new_h = max(1, int(src_h * scale))
        # Top-left corner of the resized content inside the padded canvas.
        left = (dst_w - new_w) // 2
        top = (dst_h - new_h) // 2

        canvas = Image.new('RGB', self.size, (0, 0, 0))
        canvas.paste(image.resize((new_w, new_h), Image.BICUBIC), (left, top))

        target["boxes"] = self._letterbox_boxes(target["boxes"], scale, left, top)

        if "masks" in target:
            # NEAREST keeps mask values intact; an interpolating filter would
            # invent intermediate values along region borders.
            resized_mask = target["masks"].resize((new_w, new_h), Image.NEAREST)
            # Single-band images take a scalar fill colour, not an RGB tuple.
            mask_canvas = Image.new('L', self.size, 0)
            mask_canvas.paste(resized_mask, (left, top))
            target["masks"] = mask_canvas

        # Return the padded canvas (not the merely resized image) so that the
        # image agrees with the padded mask and the shifted boxes.
        return canvas, target

    @staticmethod
    def _letterbox_boxes(boxes, scale, left, top):
        """Scale box coordinates by ``scale`` and shift them by the padding.

        Accepts a single flat ``[xmin, ymin, xmax, ymax]`` box (updated in
        place), a 2-D array/tensor with one box per row (updated in place;
        assumes a floating-point dtype -- TODO confirm against the dataset),
        or a list of 4-item sequences (a new list of lists is returned).
        """
        if getattr(boxes, "ndim", None) == 2:
            # One box per row: transform the x columns and y columns at once.
            boxes[:, [0, 2]] = boxes[:, [0, 2]] * scale + left
            boxes[:, [1, 3]] = boxes[:, [1, 3]] * scale + top
            return boxes

        if len(boxes) == 0:
            # Nothing to transform.
            return boxes

        offsets = (left, top, left, top)
        if isinstance(boxes[0], (list, tuple)):
            return [
                [coord * scale + offset for coord, offset in zip(box, offsets)]
                for box in boxes
            ]

        # Single flat box.
        for idx, offset in enumerate(offsets):
            boxes[idx] = boxes[idx] * scale + offset
        return boxes
